from typing import Optional, Callable

import numpy as np
import scipy

from ._koopman import TransferOperatorModel
from ..base import Model, Transformer, EstimatorTransformer
from ..kernels import Kernel
from ..numeric import sort_eigs, eigs
from ..util.types import to_dataset


class DMDModel(Model, Transformer):
    r""" Container for the result of a :class:`DMD` estimation, i.e., DMD eigenvalues and DMD modes.

    Parameters
    ----------
    eigenvalues : (n,) ndarray
        The DMD eigenvalues.
    modes : (n, n) ndarray
        The DMD modes.
    mode : str, optional, default='exact'
        The mode of estimation that was used. See :attr:`DMD.available_modes`.
    """

    def __init__(self, eigenvalues: np.ndarray, modes: np.ndarray, mode='exact'):
        super().__init__()
        self.eigenvalues = eigenvalues
        self.modes = modes
        self.mode = mode

    def transform(self, data: np.ndarray, **kwargs):
        r""" Propagates an input trajectory with the dynamics captured by this model.

        The transposed input is multiplied from the left by the transposed modes, the diagonal matrix of
        eigenvalues, and a factor that depends on :attr:`mode`: the pseudo-inverse of the transposed modes
        if the mode is `'exact'`, the complex-conjugated modes otherwise. The product is transposed back
        before it is returned.

        Parameters
        ----------
        data : (T, n) ndarray
            Input trajectory
        **kwargs
            Ignored, present for interface compatibility.

        Returns
        -------
        output_trajectory : (T, n) ndarray
            The propagated trajectory.
        """
        modes_t = self.modes.T
        if self.mode == 'exact':
            right_factor = np.linalg.pinv(modes_t)
        else:
            right_factor = self.modes.conj()
        propagated = np.linalg.multi_dot([modes_t, np.diag(self.eigenvalues), right_factor, data.T])
        return propagated.T


class DMD(EstimatorTransformer):
    r""" Dynamic mode decomposition estimator. :footcite:`schmid2010dynamic`

    There are two supported modes:

    * `standard`, which produces "projected" DMD modes (following the original formulation of DMD),
    * `exact`, which produces DMD modes that do not require ordered data but just
      matched pairs of data :footcite:`tu2013dynamic`.

    Parameters
    ----------
    mode : str, default='exact'
        The estimation mode, see :attr:`available_modes` for available modes.
    rank : int or None, optional, default=None
        Truncation of the rank after performing SVD.
    driver : str, default='scipy'
        Which package to use for the SVD and the eigenvalue decomposition, see :attr:`available_drivers`.

    Raises
    ------
    ValueError
        If `mode` is not one of :attr:`available_modes` or `driver` is not one of :attr:`available_drivers`.

    Notes
    -----
    In standard DMD, one considers a temporally ordered list of
    data vectors :math:`(z_0,\ldots,z_T)\in\mathbb{R}^{T\times d}`. The data is split into the pair

    .. math::

        X = (z_0, \ldots, z_{T-1}),\quad Y=(z_1,\ldots, z_T).

    If the mode is `exact`, the list does not need to be temporally ordered but just the pairs :math:`(X_i, Y_i)`
    have to match. The underlying assumption is that the data are generated by a linear relationship

    .. math::

        z_{t+1} = A z_t

    for some matrix :math:`A`.

    The so-called DMD modes and eigenvalues are then the (potentially scaled) eigenvectors and eigenvalues of :math:`A`.

    References
    ----------
    .. footbibliography::
    """

    available_modes = 'exact', 'standard'  #: The available estimation modes.
    available_drivers = 'numpy', 'scipy'  #: The available drivers.

    def __init__(self, mode='exact', rank=None, driver='scipy'):
        super().__init__()
        if driver not in DMD.available_drivers:
            raise ValueError(f"Invalid driver {driver}, must be one of {DMD.available_drivers}.")
        self.mode = mode  # validated by the property setter
        self.rank = rank
        self.driver = driver

    @property
    def mode(self):
        r""" The estimation mode, one of :attr:`available_modes`. """
        return self._mode

    @mode.setter
    def mode(self, value):
        if value not in DMD.available_modes:
            raise ValueError(f"Invalid mode {value}, must be one of {DMD.available_modes}.")
        self._mode = value

    def _svd(self, mat, **kw):
        r""" Singular value decomposition of `mat` using the configured driver. """
        if self.driver == 'numpy':
            return np.linalg.svd(mat, **kw)
        if self.driver == 'scipy':
            return scipy.linalg.svd(mat, **kw)
        # the driver is only validated in __init__, it may have been reassigned afterwards
        raise ValueError(f"Invalid driver {self.driver}, must be one of {DMD.available_drivers}.")

    def _eig(self, mat, **kw):
        r""" Eigenvalue decomposition of `mat` using the configured driver. """
        if self.driver == 'numpy':
            return np.linalg.eig(mat, **kw)
        if self.driver == 'scipy':
            return scipy.linalg.eig(mat, **kw)
        # the driver is only validated in __init__, it may have been reassigned afterwards
        raise ValueError(f"Invalid driver {self.driver}, must be one of {DMD.available_drivers}.")

    def fit(self, data, **kwargs):
        r""" Fit this estimator instance onto data.

        Computes a (optionally rank-truncated) SVD of the instantaneous data, projects the time-shifted data
        onto the left singular vectors, and diagonalizes the resulting matrix. The eigenvalues are sorted
        lexicographically and the DMD modes are assembled according to :attr:`mode`.

        Parameters
        ----------
        data
            Input data, see :meth:`to_dataset <deeptime.util.types.to_dataset>` for options.
        **kwargs
            Kwargs, may contain lagtime.

        Returns
        -------
        self : DMD
            Reference to self.
        """
        dataset = to_dataset(data, lagtime=kwargs.get("lagtime", None))
        X, Y = dataset[:]
        X, Y = X.T, Y.T  # per convention arrays are [T, d] so here we transpose them

        U, s, Vt = self._svd(X, full_matrices=False)
        if self.rank is not None:
            # never request more components than the SVD produced
            rank = min(self.rank, U.shape[1])
            U = U[:, :rank]
            s = s[:rank]
            Vt = Vt[:rank]
        V = Vt.conj().T
        S_inv = np.diag(1 / s)
        A = np.linalg.multi_dot([U.conj().T, Y, V, S_inv])

        eigenvalues, eigenvectors = self._eig(A)
        eigenvalues, eigenvectors = sort_eigs(eigenvalues, eigenvectors, order='lexicographic')

        if self.mode == 'exact':
            dmd_modes = np.linalg.multi_dot([Y, V, S_inv, eigenvectors, np.diag(1 / eigenvalues)])
        elif self.mode == 'standard':
            dmd_modes = U @ eigenvectors

        # hand the estimation mode to the model, its transform depends on it
        self._model = DMDModel(eigenvalues, dmd_modes.T, mode=self.mode)
        return self

    def fetch_model(self) -> Optional[DMDModel]:
        r""" Yields the estimated model if :meth:`fit` was called.

        Returns
        -------
        model : DMDModel or None
            The model or None.
        """
        return self._model

    def transform(self, data, **kwargs):
        r""" See :meth:`DMDModel.transform`.

        Parameters
        ----------
        data : (T, d) np.ndarray
            Input data
        **kwargs
            Forwarded to :meth:`DMDModel.transform`.

        Returns
        -------
        result : (T, d) np.ndarray
            Propagated input data
        """
        return self.fetch_model().transform(data, **kwargs)


class EDMDModel(TransferOperatorModel):
    r""" Model produced by the :class:`EDMD` estimator. It stores the estimated operator together with its
    eigenvalues and modes and can project data onto these modes.

    Parameters
    ----------
    operator : (n, n) ndarray
        The estimated operator.
    basis : callable
        The basis transform that was used. It serves as both the instantaneous and the time-lagged observable.
    eigenvalues : (n, ) ndarray
        Eigenvalues for the modes.
    modes : (k, n) ndarray
        The EDMD modes.

    See Also
    --------
    EDMD
    """

    def __init__(self, operator: np.ndarray, basis: Callable[[np.ndarray], np.ndarray], eigenvalues, modes):
        super().__init__(operator, instantaneous_obs=basis, timelagged_obs=basis)
        self.basis = basis
        self.modes = modes
        self.eigenvalues = eigenvalues
        self.n_eigenvalues = len(self.eigenvalues)

    def transform(self, data: np.ndarray, **kw):
        r"""Transforms the data by first applying the basis and then the estimated modes, i.e.,

        .. math::

            X \mapsto \Psi(X)\varphi

        where :math:`X` is the input data, :math:`\Psi` the basis transform, and :math:`\varphi` the matrix
        whose columns are the first `n_eigenvalues` estimated modes.

        Parameters
        ----------
        data : (T, n) ndarray
            Input data
        **kw
            Ignored.

        Returns
        -------
        transformed : ndarray
            The forward transform of the input data, one row per input frame.
        """
        features = self.instantaneous_obs(data)
        leading_modes = self.modes[:self.n_eigenvalues]
        return features @ leading_modes.T


class EDMD(EstimatorTransformer):
    r""" Extended dynamic mode decomposition for estimation of the Koopman (or optionally Perron-Frobenius)
    operator. :footcite:`williams2015data` :footcite:`klus2016edmd`.

    The estimator needs a basis :math:`\Psi : \mathbb{R}^n\to\mathbb{R}^k, \mathbf{x}\mapsto\Psi(\mathbf{x})`
    and data matrices :math:`X = [x_1,\ldots,x_M]`, :math:`Y=[y_1,\ldots,y_M]` of time-lagged pairs of data.
    It then estimates a Koopman operator approximation :math:`K` so that :math:`\Psi(y_i)\approx K^\top \Psi(x_i)`.

    In other words, for data matrices :math:`\Psi_X` and :math:`\Psi_Y` it solves the minimization problem

    .. math::

        \min\| \Psi_Y - K\Psi_X\|_F.

    Parameters
    ----------
    basis : callable
        The basis callable, maps from (T, k) ndarray to (T, m) ndarray. See :mod:`deeptime.basis` for a selection
        of pre-defined bases.
    n_eigs : int, optional, default=None
        The number of eigenvalues, determining the number of dominant singular functions / modes being estimated.
        If None, estimates all eigenvalues / eigenvectors.
    operator : str, default='koopman'
        Which operator to estimate, see :attr:`available_operators`.

    Raises
    ------
    ValueError
        If `operator` is not one of :attr:`available_operators`.

    References
    ----------
    .. footbibliography::
    """

    available_operators = 'koopman', 'perron-frobenius'  #: The supported operators.

    def __init__(self, basis: Callable[[np.ndarray], np.ndarray], n_eigs: Optional[int] = None,
                 operator: str = 'koopman'):
        super().__init__()
        # fit() treats everything that is not 'koopman' as Perron-Frobenius, so reject typos early
        if operator not in EDMD.available_operators:
            raise ValueError(f"Invalid operator {operator}, must be one of {EDMD.available_operators}.")
        self.basis = basis
        self.operator = operator
        self.n_eigs = n_eigs

    def fetch_model(self) -> Optional[EDMDModel]:
        r""" Yields the estimated model or None.

        Returns
        -------
        model : EDMDModel or None
            The model.
        """
        return self._model

    def fit(self, data, **kwargs):
        r""" Fit this estimator instance onto data.

        Evaluates the basis on instantaneous and time-lagged data, builds the instantaneous covariance and the
        cross-covariance of the basis evaluations (both normalized by the number of frames), and diagonalizes
        the product of the pseudo-inverse of the former with the latter. For any operator other than
        `'koopman'` the cross-covariance is transposed first.

        Parameters
        ----------
        data
            Input data, see :meth:`to_dataset <deeptime.util.types.to_dataset>` for options.
        **kwargs
            Kwargs, may contain lagtime.

        Returns
        -------
        self : EDMD
            Reference to self.
        """
        dataset = to_dataset(data, lagtime=kwargs.get("lagtime", None))
        x, y = dataset[:]
        n_data = x.shape[0]
        assert n_data == y.shape[0], "Trajectories for data and timelagged data must be of same length!"
        # basis evaluations are transposed so that frames are stored column-wise
        psi_x = self.basis(x).T
        psi_y = self.basis(y).T

        cov_00 = (1 / n_data) * psi_x @ psi_x.T
        cov_0t = (1 / n_data) * psi_x @ psi_y.T

        if self.operator != 'koopman':
            cov_0t = cov_0t.T

        m_edmd = scipy.linalg.pinv(cov_00) @ cov_0t
        eig_val, eig_vec = eigs(m_edmd, self.n_eigs)
        eig_val, eig_vec = sort_eigs(eig_val, eig_vec, order='lexicographic')
        # modes are stored row-wise in the model
        self._model = EDMDModel(m_edmd, basis=self.basis, eigenvalues=eig_val, modes=eig_vec.T)
        return self


class KernelEDMDModel(TransferOperatorModel):
    r""" Model of a kernel EDMD estimation, holding eigenvalues and eigenfunctions evaluated in the
    instantaneous data.

    Parameters
    ----------
    data : ndarray
        The data against which the kernel is evaluated when computing observables.
    eigenvalues : (d,) ndarray
        The eigenvalues.
    eigenvectors : (T, d) ndarray
        The eigenfunction evaluation.
    kernel : Kernel
        The kernel that was used for estimation.

    See Also
    --------
    KernelEDMD
    """

    def __init__(self, data: np.ndarray, eigenvalues: np.ndarray, eigenvectors: np.ndarray, kernel: Kernel):
        # Both observables look up kernel and data on self lazily, i.e., at call time; the attributes are
        # only assigned after the base class constructor has run.
        def _instantaneous(x):
            return self.kernel.apply(x, self.data)

        def _timelagged(x):
            return self.kernel.apply(x, self.data)

        super().__init__(eigenvectors @ np.diag(eigenvalues),
                         instantaneous_obs=_instantaneous,
                         timelagged_obs=_timelagged)
        self.data = data
        self.eigenvalues = eigenvalues
        self.eigenvectors = eigenvectors
        self.kernel = kernel

    def transform(self, data: np.ndarray, propagate=True):
        r""" Evaluates the instantaneous observable on the data and projects onto the eigenvectors.

        Parameters
        ----------
        data : ndarray
            Input data.
        propagate : bool, default=True
            Not used by this implementation.

        Returns
        -------
        transformed : ndarray
            Kernel evaluation of the input against the stored data, multiplied with the eigenvectors.
        """
        kernel_evaluation = self.instantaneous_obs(data)
        return kernel_evaluation @ self.eigenvectors


class KernelEDMD(EstimatorTransformer):
    r""" Estimator implementing kernel extended mode decomposition. :footcite:`williams2016kernel`
    :footcite:`klus2019eigendecomposition` :footcite:`klus2018kernel`.

    Parameters
    ----------
    kernel : Kernel
        The kernel to use. See :mod:`deeptime.kernels` for a list of available kernels.
    epsilon : float, optional, default=0
        Regularization parameter. It scales the identity matrix that is added to the Gram matrix of the
        instantaneous data before solving the linear system.
    n_eigs : int, optional, default=None
        Number of eigenvalue/eigenvector pairs to compute.

    References
    ----------
    .. footbibliography::
    """

    def __init__(self, kernel: Kernel, epsilon: float = 0., n_eigs: Optional[int] = None):
        super().__init__()
        self.kernel = kernel
        self.epsilon = epsilon
        self.n_eigs = n_eigs

    def fetch_model(self) -> Optional[KernelEDMDModel]:
        r""" Yields the estimated model or `None`.

        Returns
        -------
        model : KernelEDMDModel or None
            The model.
        """
        return super().fetch_model()

    def fit(self, data, **kwargs):
        r""" Fit this estimator instance onto data.

        Builds the Gram matrix of the instantaneous data and the kernel evaluation between instantaneous and
        time-lagged data, solves the regularized linear system, and computes the sorted eigenvalues and
        eigenvectors of the solution.

        Parameters
        ----------
        data
            Input data, see :meth:`to_dataset <deeptime.util.types.to_dataset>` for options.
        **kwargs
            Kwargs, may contain lagtime.

        Returns
        -------
        self : KernelEDMD
            Reference to self.
        """
        lagtime = kwargs.get("lagtime", None)
        x, y = to_dataset(data, lagtime=lagtime)[:]
        gram_xx = self.kernel.gram(x)
        gram_xy = self.kernel.apply(x, y)

        regularized = gram_xx + self.epsilon * np.eye(gram_xx.shape[0])
        A = scipy.linalg.solve(regularized, gram_xy.T, assume_a='sym')
        evals, evecs = eigs(A, n_eigs=self.n_eigs)
        evals, evecs = sort_eigs(evals, evecs)
        self._model = KernelEDMDModel(x, evals, evecs, self.kernel)
        return self
